AWS SageMaker is a managed service for training models automatically, and its instances cost about 1.5x as much as the equivalent normal elastic compute instances. Distributed training is therefore both important and expensive, so it pays to understand how it works.
Introduction to distributed learning in PyTorch
We first need to know which kinds of distributed learning we can use. The PyTorch documentation describes DDP and collective communication (read the documentation for the details).
Data Parallel
DistributedDataParallel (DDP) is preferred over DataParallel (DP), because DP runs a single process with one thread per GPU and is therefore limited by the GIL. DP splits each input batch across the GPUs of a single machine, runs the forward pass on every GPU in parallel, and then gathers the outputs (and, in the backward pass, reduces the gradients) on the first GPU. For example, with a batch size of 16 and 4 GPUs, each GPU processes a batch of 4. To use it, we only need to add a few lines of code:
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
Apart from moving the model and the inputs to the GPU as usual, nothing else needs to change for it to run.
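To make the batch-size arithmetic concrete, here is a pure-Python sketch of the scatter/gather pattern DP follows. It assumes the split happens along the first (batch) dimension with torch.chunk-style sizes (each chunk holds ceil(batch / devices) samples, so the last one may be smaller); plain lists stand in for tensors and GPUs, and the helper names scatter_batch and gather_outputs are hypothetical.

```python
import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def scatter_batch(batch: Sequence[T], num_devices: int) -> List[List[T]]:
    """Split a batch along its first dimension into at most num_devices chunks.

    Chunk size is ceil(len(batch) / num_devices), the same rule torch.chunk uses,
    so the last chunk may be smaller and some devices may receive nothing.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    size = math.ceil(len(batch) / num_devices)
    if size == 0:
        return []
    return [list(batch[i:i + size]) for i in range(0, len(batch), size)]


def gather_outputs(chunks: List[List[T]]) -> List[T]:
    """Concatenate the per-device outputs back into one batch, in device order."""
    return [item for chunk in chunks for item in chunk]


if __name__ == "__main__":
    batch = list(range(16))           # a batch of 16 samples
    chunks = scatter_batch(batch, 4)  # 4 GPUs
    print([len(c) for c in chunks])   # [4, 4, 4, 4]
    # each "device" runs the forward pass on its own chunk (here: double each input)
    outputs = [[2 * x for x in c] for c in chunks]
    print(gather_outputs(outputs) == [2 * x for x in batch])  # True
```

Because every forward pass goes through this scatter and gather inside one Python process, the per-GPU work is driven by threads, which is where the GIL limitation comes from.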
Distributed Data Parallel (DDP)
DDP needs a process group from torch.distributed in addition to the DDP wrapper itself. Because it launches one process per GPU instead of one thread per GPU, it is not limited by the GIL. A code example:
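import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
class SimpleCNN(nn.Module):
    # hypothetical minimal model for 1x28x28 inputs; the original definition is not shown
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8 * 28 * 28, 10)
    def forward(self, x):
        return self.fc(torch.relu(self.conv(x)).flatten(1))
def example(rank, world_size):
    # create default process group; "nccl" is the backend for GPUs
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # create local model and move it to the GPU with index `rank`
    model = SimpleCNN().to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    # forward pass
    outputs = ddp_model(torch.randn(64, 1, 28, 28).to(rank))  # Example input size for MNIST
    labels = torch.randint(0, 10, (64,)).to(rank)  # Example labels for 64 samples
    # backward pass (DDP synchronizes the gradients across ranks here)
    loss_fn(outputs, labels).backward()
    # update parameters
    optimizer.step()
    dist.destroy_process_group()
def main():
    world_size = 2  # needs at least 2 GPUs on this machine
    # start one process per rank; each one calls example(rank, world_size)
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)
if __name__ == "__main__":
    # Environment variables for distributed training
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()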
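The call loss_fn(outputs, labels).backward() is where DDP synchronizes: during the backward pass it all-reduces the gradients so that every rank ends up with the same averaged gradient, and therefore the same weights after the parameter update. The pure-Python sketch below simulates that for a one-parameter model y_hat = w * x with a mean-squared-error loss. The helper names are hypothetical, plain floats stand in for tensors, and it assumes the batch divides evenly across ranks; under that assumption the averaged gradient equals the full-batch gradient.

```python
from typing import List


def local_gradient(w: float, xs: List[float], ys: List[float]) -> float:
    """d/dw of the mean squared error of y_hat = w * x over one shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values: List[float]) -> List[float]:
    """Simulated all-reduce: every rank ends up with the mean of all ranks' values."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def ddp_step(w: float, xs: List[float], ys: List[float],
             world_size: int, lr: float) -> List[float]:
    """One simulated DDP step; returns the updated weight held by each rank."""
    shard = len(xs) // world_size  # assumes len(xs) is divisible by world_size
    grads = [
        local_gradient(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    synced = all_reduce_mean(grads)      # what DDP does during backward()
    return [w - lr * g for g in synced]  # every rank applies the same SGD update


if __name__ == "__main__":
    xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
    print(local_gradient(0.0, xs, ys))                # full-batch gradient
    print(ddp_step(0.0, xs, ys, world_size=2, lr=0.01))  # identical weights on both ranks
```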
rank and world_size are two special concepts in distributed learning. When we launch multiple processes to train the model, the total number of processes is world_size, and each process is identified by its rank, an integer from 0 to world_size - 1. In the single-machine example above, each rank also uses the GPU with the same index, so .to(rank) places the model and data on that process's GPU, just as .to("cuda") would in single-GPU code.
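Besides choosing a device, the rank usually decides which part of the dataset a process reads. The sketch below is intended to mirror the default unshuffled, drop_last=False behaviour of torch.utils.data.distributed.DistributedSampler: pad the index list by wrapping around so it divides evenly, then give each rank a strided slice. The helper shard_indices is hypothetical; in real code you would pass DistributedSampler(dataset, num_replicas=world_size, rank=rank) to the DataLoader instead.

```python
import math
from typing import List


def shard_indices(dataset_len: int, rank: int, world_size: int) -> List[int]:
    """Indices of the samples that process `rank` should read (no shuffling)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    num_samples = math.ceil(dataset_len / world_size)  # samples per rank
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    padding = total_size - len(indices)
    # wrap around to the start so every rank gets the same number of samples
    indices += (indices * math.ceil(padding / max(len(indices), 1)))[:padding]
    return indices[rank:total_size:world_size]


if __name__ == "__main__":
    for r in range(4):
        print(r, shard_indices(10, r, world_size=4))
```

With 10 samples and 4 ranks, two samples are repeated so that every rank gets 3; the equal counts keep the ranks in lockstep, since each one waits for the others at every gradient synchronization.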
We still need to define the environment variables MASTER_ADDR and MASTER_PORT, because the processes have to set up a communication network: all of them connect to this address and port to find each other.
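By default, init_process_group uses init_method="env://", which reads MASTER_ADDR and MASTER_PORT; the rank 0 process listens there and the other ranks connect to it. The same rendezvous point can also be passed explicitly as init_method="tcp://host:port". The hypothetical helper below only builds that URL from an environment mapping, to show what the two variables mean; a script could then call dist.init_process_group("nccl", init_method=url, rank=rank, world_size=world_size).

```python
from typing import Mapping


def rendezvous_url(env: Mapping[str, str]) -> str:
    """Build the tcp:// address that all processes connect to."""
    missing = [k for k in ("MASTER_ADDR", "MASTER_PORT") if k not in env]
    if missing:
        raise KeyError(f"missing environment variables: {missing}")
    port = int(env["MASTER_PORT"])
    if not 0 < port < 65536:
        raise ValueError("MASTER_PORT must be a valid TCP port")
    return f"tcp://{env['MASTER_ADDR']}:{port}"


if __name__ == "__main__":
    print(rendezvous_url({"MASTER_ADDR": "localhost", "MASTER_PORT": "29500"}))
```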
Use the ZeroRedundancyOptimizer
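Some optimizers, such as Adam, keep extra per-parameter state (two moment estimates, i.e. about twice the model size), so out-of-memory (OOM) errors can occur. DeepSpeed's ZeRO (Zero Redundancy Optimizer) tackles this by sharding the optimizer states across processes, and PyTorch already ships an implementation: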
from torch.distributed.optim import ZeroRedundancyOptimizer
To use it, guard the optimizer construction with a flag, here called use_zero:
if use_zero:
    optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
else:
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
This technique shards the optimizer states across processes, so each rank stores the states for only part of the parameters, which avoids OOM. The rest of the code is the same as in the DDP section.
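As a rough model of what ZeroRedundancyOptimizer saves, the sketch below counts Adam's state elements (exp_avg plus exp_avg_sq, ignoring the scalar step counter) per rank, with and without sharding. Each rank is assigned a subset of the parameters and keeps optimizer states for that subset only; after stepping its subset, the updated parameters are synchronized across ranks. The greedy balancing rule in partition_params and both helper names are assumptions for illustration, not the library's internals.

```python
from typing import Dict, List


def partition_params(param_sizes: List[int], world_size: int) -> List[List[int]]:
    """Assign parameter indices to ranks, balancing the number of elements.

    Greedy rule (an assumption): largest parameter first, each to the
    currently lightest rank.
    """
    buckets: List[List[int]] = [[] for _ in range(world_size)]
    loads = [0] * world_size
    for idx in sorted(range(len(param_sizes)), key=lambda i: -param_sizes[i]):
        rank = loads.index(min(loads))
        buckets[rank].append(idx)
        loads[rank] += param_sizes[idx]
    return buckets


def adam_state_per_rank(param_sizes: List[int], world_size: int) -> Dict[str, List[int]]:
    """Adam state elements (two moment buffers) that each rank has to store."""
    replicated = [2 * sum(param_sizes)] * world_size  # plain DDP: full copy per rank
    sharded = [2 * sum(param_sizes[i] for i in bucket)
               for bucket in partition_params(param_sizes, world_size)]
    return {"replicated": replicated, "sharded": sharded}


if __name__ == "__main__":
    print(adam_state_per_rank([4, 3, 3, 2], world_size=2))
```

With two ranks the per-rank optimizer state is halved; in general it shrinks roughly by a factor of world_size, while the model parameters and gradients are still fully replicated as in DDP.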
Fully sharded data parallel
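Fully Sharded Data Parallel (FSDP) shards the model parameters, gradients, and optimizer states across all processes, while each process still trains on its own part of the data. This makes it especially useful for models that cannot fit on one GPU. For an example script, refer to this code example.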
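The core FSDP idea can be simulated in plain Python: each rank permanently stores only a 1/world_size shard of a flattened, padded parameter; before the computation the shards are all-gathered into the full parameter, and after the backward pass the gradients are reduce-scattered so that each rank keeps only the averaged gradient for its own shard. The helper names are hypothetical; real FSDP does this per wrapped module, on tensors, with collective operations.

```python
import math
from typing import List


def shard_flat_param(flat: List[float], world_size: int) -> List[List[float]]:
    """Pad a flattened parameter to a multiple of world_size and split it evenly."""
    shard_len = math.ceil(len(flat) / world_size)
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Simulated all-gather before forward/backward: rebuild the full parameter."""
    full = [x for shard in shards for x in shard]
    return full[:numel]  # drop the padding


def reduce_scatter(per_rank_grads: List[List[float]], world_size: int) -> List[List[float]]:
    """Simulated reduce-scatter after backward.

    Each entry of per_rank_grads is one rank's full-length gradient (computed on
    its own data); the result is the averaged gradient, sharded by rank.
    """
    averaged = [sum(g) / len(per_rank_grads) for g in zip(*per_rank_grads)]
    return shard_flat_param(averaged, world_size)


if __name__ == "__main__":
    params = [float(i) for i in range(10)]
    shards = shard_flat_param(params, 4)
    print(shards)                              # each rank stores 3 values
    print(all_gather(shards, len(params)) == params)
```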
torchrun is a launcher for running distributed training elastically. It can cope with nodes that fail, and it restarts the training processes automatically.
We should save a checkpoint regularly (here, once per epoch) so that a restart loses at most one epoch of training. The code looks like this:
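import sys
import torch.distributed as dist
def main():
    # parse_args, load_checkpoint, train and save_checkpoint are user-defined helpers
    args = parse_args(sys.argv[1:])
    state = load_checkpoint(args.checkpoint_path)
    # torchrun ensures that this will work
    # by exporting all the env vars needed to initialize the process group
    dist.init_process_group(backend=args.backend)
    for i in range(state.epoch, state.total_num_epochs):
        for batch in iter(state.dataset):
            train(batch, state.model)
        state.epoch += 1
        save_checkpoint(state)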
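The checkpoint-and-resume logic can be tried without GPUs. In this sketch a JSON file stands in for the checkpoint (a real script would use torch.save and torch.load on the model and optimizer state dicts), a list of floats stands in for the model, and a raised exception simulates a node failure. All function names are hypothetical. In a multi-process job you would typically let only rank 0 write the checkpoint.

```python
import json
import os
import tempfile
from typing import List


def save_checkpoint(path: str, epoch: int, weights: List[float]) -> None:
    """Write to a temp file, then rename, so a crash mid-save cannot corrupt it."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump({"epoch": epoch, "weights": weights}, f)
    os.replace(tmp, path)  # atomic rename on POSIX


def load_checkpoint(path: str) -> dict:
    """Return the saved state, or a fresh state if no checkpoint exists yet."""
    if not os.path.exists(path):
        return {"epoch": 0, "weights": [0.0]}
    with open(path) as f:
        return json.load(f)


def train(path: str, total_epochs: int, crash_at: int = -1) -> List[int]:
    """Run the remaining epochs from the last checkpoint; return the epochs run."""
    state = load_checkpoint(path)
    ran = []
    for epoch in range(state["epoch"], total_epochs):
        if epoch == crash_at:
            raise RuntimeError(f"simulated node failure in epoch {epoch}")
        state["weights"] = [w + 1.0 for w in state["weights"]]  # stand-in for one epoch
        ran.append(epoch)
        state["epoch"] = epoch + 1
        save_checkpoint(path, state["epoch"], state["weights"])
    return ran


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        ckpt = os.path.join(d, "ckpt.json")
        try:
            train(ckpt, total_epochs=5, crash_at=3)
        except RuntimeError as err:
            print(err)  # torchrun would restart the workers at this point
        print(train(ckpt, total_epochs=5))  # resumes at epoch 3
```

Because the checkpoint is written at the end of every epoch, the restarted run only repeats the epoch that was interrupted.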
For more on using torchrun, refer to this page. Here is another script that can be run with the torchrun command. Before running a script with torchrun, we should first make sure that it is adapted to torchrun. The command to run it is:
--rdzv-endpoint=HOST_NODE_ADDR (--arg1 ... train script args...)
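Adapting a script to torchrun mostly means not spawning the processes yourself: torchrun starts the worker processes and exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT to each of them, so dist.init_process_group(backend=...) can be called without explicit rank and world_size arguments. The sketch below reads those values; DistInfo and dist_info_from_env are hypothetical names, and the single-process fallback is an assumption for running the same script with plain python. A real script would then call torch.cuda.set_device(info.local_rank) and wrap the model with DDP(model, device_ids=[info.local_rank]).

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class DistInfo:
    rank: int        # global rank across all nodes
    local_rank: int  # index of this process on its node, used to pick the GPU
    world_size: int  # total number of processes


def dist_info_from_env(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Read the variables torchrun exports, defaulting to a single process."""
    env = os.environ if env is None else env
    return DistInfo(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )


if __name__ == "__main__":
    print(dist_info_from_env())
```

Distinguishing LOCAL_RANK from RANK matters once several nodes are involved: RANK is unique across the whole job, while LOCAL_RANK restarts from 0 on every node and is the right index for choosing a GPU.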
For more complicated cases, PyTorch also provides solutions for running across multiple communicating containers; see the Docker example or the Kubernetes (k8s) example.